#include "StdAfx.h"
#include "IdidEnterModels.h"

//////////////////////////////////////////////////////////////////////////
//Author	:	Ross Conroy ross.conroy@tees.ac.uk
//Date		:	09/08/2014
//
//This class is responsible for adding DID models into an I-DID using the
//following steps
//
//Loop through all model nodes in order (starting at t = 1), expanding
//the number of states of each node for its time step
//	if t = 1 then num states = num models
//	else
//	num states = (num models)x(num states for all parents)
//
//Loop through adding observation data into aj observation nodes
//Loop through adding action data into aj action nodes
//////////////////////////////////////////////////////////////////////////

//Default constructor. The body is intentionally empty: nothing is
//initialised here.
IdidEnterModels::IdidEnterModels()
{
}

//Domain * IdidEnterModels::EnterPprModelsIntoIdid(Domain * inputIDID, hash_map<string, PPR> pprs, string statePrefix, 
//												string utilityPrefix, string iObservationPrefix, string jObservationPrefix, 
//												string iActionPrefix, string jActionPrefix, string modelPrefix, int NumModels)
//{
//	inputIDID = ExpandModelNodesSteps(inputIDID, modelPrefix, NumModels);
//
//	ExpandTimeSteps expander;
//	int numSteps =  expander.FindLastTimeStep(inputIDID);
//
//	for(int i = 1; i <= numSteps; i++)
//	{
//		stringstream ss;
//		ss << jActionPrefix;
//		ss << i;
//		DiscreteNode * modelActionNode = (DiscreteNode *)inputIDID->getNodeByName(ss.str());
//		DiscreteNode * parent = (DiscreteNode *)modelActionNode->getParents().front();
//		
//		int tableIndex = 0;
//
//		for(int s = 0; s < parent->getNumberOfStates(); s++)
//		{
//			vector<int> ocurances (modelActionNode->getNumberOfStates());
//			//loop through each PPR to find if any match the search string
//			for (hash_map<string, PPR>::iterator itp = pprs.begin(); itp != pprs.end(); itp++)
//			{
//				PPR ppr = itp->second;
//				OcurancesOfAction ocurancesAction = ppr.GetOcurancesOfActionFromSearchString(parent->getStateLabel(s));
//				
//				//if found something then add it to the vector of ocurances
//				if(ocurancesAction.ocurances != 0)
//				{
//					int stateIndex = modelActionNode->getStateIndex(ocurancesAction.action);
//					ocurances[stateIndex] += ocurancesAction.ocurances;
//				}
//			}
//		}
//	}
//
//	return inputIDID;
//}


//////////////////////////////////////////////////////////////////////////
//Enters a set of DID models into an I-DID.
//
//Steps, in order:
//	1. convert every model into a policy tree (PolicyTreeBuilder)
//	2. label the models "0", "1", ... in list order
//	3. expand the model nodes of the I-DID (ExpandModelNodesSteps)
//	4. fill the j action node tables from the policy trees
//	   (UpdateActionNodesTables)
//	5. fill the j observation node tables from the models
//	   (UpdateObservationNodesTables)
//
//inputIDID          - I-DID that is modified
//models             - DID models to enter, their order defines the
//                     model index
//jObservationPrefix - name prefix of the j observation nodes
//jActionPrefix      - name prefix of the j action nodes
//modelPrefix        - name prefix of the model nodes
//statePrefix, utilityPrefix, iObservationPrefix, iActionPrefix are
//accepted but not read by this function.
//
//Returns the domain handed back by the last update step.
//
//NOTE(review): the PolicyNode pointers returned by ConvertToPolicyTree
//are not released here - confirm who owns them.
//////////////////////////////////////////////////////////////////////////
Domain * IdidEnterModels::EnterDidModelsIntoIdid(Domain * inputIDID, list<Domain *> models, string statePrefix, 
										string utilityPrefix, string iObservationPrefix, string jObservationPrefix, 
										string iActionPrefix, string jActionPrefix, string modelPrefix)
{
	const int modelCount = models.size();

	//one policy tree per model, stored in the same order as the list
	vector<PolicyNode*> trees;
	PolicyTreeBuilder treeBuilder;
	list<Domain*>::iterator modelIt = models.begin();
	while(modelIt != models.end())
	{
		Domain * currentModel = *modelIt;
		trees.push_back(treeBuilder.ConvertToPolicyTree(currentModel, jActionPrefix, jObservationPrefix));
		++modelIt;
	}

	//the label of a model is simply its position in the list
	vector<string> labels;
	for(int index = 0; index < modelCount; ++index)
	{
		stringstream label;
		label << index;
		labels.push_back(label.str());
	}

	Domain * result = ExpandModelNodesSteps(inputIDID, modelPrefix, modelCount, labels);
	result = UpdateActionNodesTables(result, trees, jActionPrefix);
	result = UpdateObservationNodesTables(result, models, jObservationPrefix, modelCount);

	return result;
}


//////////////////////////////////////////////////////////////////////////
//Sets the states, state labels and table of every model node, i.e. the
//nodes named <ModelPrefix>1 .. <ModelPrefix>N where N is the last time
//step reported by ExpandTimeSteps::FindLastTimeStep.
//
//Time step 1 : numModels states, labelled from modelNames (the state
//              index is used as label when modelNames has no entry for
//              a state); every table entry is set to 1.
//Later steps : one state per combination of parent states, labelled
//              with the parent state labels joined by ","; the table is
//              set to 0 except every (number of states + 1)-th entry,
//              which is set to 1 (the diagonal when the table is
//              square).
//
//inputIDID   - domain whose model nodes are modified in place
//ModelPrefix - model node name without the time step suffix
//numModels   - number of states of the first model node
//modelNames  - labels for the states of the first model node
//
//Returns inputIDID.
//
//NOTE(review): the result of getNodeByName is used unchecked; every
//time step is expected to have a model node.
//////////////////////////////////////////////////////////////////////////
Domain * IdidEnterModels::ExpandModelNodesSteps(Domain * inputIDID, string ModelPrefix, int numModels, vector<string> modelNames)
{
	ExpandTimeSteps expander;
	int numSteps =  expander.FindLastTimeStep(inputIDID);

	for(int i = 1; i <= numSteps; i++)
	{
		stringstream ss;
		ss << ModelPrefix;
		ss << i;
		DiscreteNode * modelNode = (DiscreteNode *)inputIDID->getNodeByName(ss.str());
		
		//For first node ignore parents and only deal with setting number of states = number of models
		if(i == 1)
		{
			modelNode->setNumberOfStates(numModels);
			for(int j = 0; j < numModels; j++)
			{
				if(static_cast<size_t>(j) < modelNames.size())
				{
					modelNode->setStateLabel(j, modelNames[j]);
				}
				else
				{
					//fewer names than states: label the state with its index
					//instead of reading past the end of modelNames
					stringstream indexLabel;
					indexLabel << j;
					modelNode->setStateLabel(j, indexLabel.str());
				}
			}

			//tableData is a local copy, so the 1's only reach the node
			//once they are written back with setData
			NumberList tableData = modelNode->getTable()->getData();
			for(size_t k = 0; k < tableData.size(); k++)
			{
				tableData[k] = 1;
			}
			modelNode->getTable()->setData(tableData);
		}
		//For all other nodes need to generate combinations of parents then set the CPT to diagonal 1's
		else
		{
			//collect the number of states of every parent
			NodeList parents = modelNode->getParents();
			list<int> sizes;

			for(NodeList::const_iterator itp = parents.begin(); itp != parents.end(); ++itp)
			{
				DiscreteNode * parent = (DiscreteNode *)*itp;
				sizes.push_back(parent->getNumberOfStates());
			}

			CombinationGenerator generator;
			list<list<int> > combinations = generator.GenerateCombinations(sizes);

			//move the last generated combination to the front; guarded so an
			//empty result is left alone
			if(!combinations.empty())
			{
				list<int> lastCombination = combinations.back();
				combinations.pop_back();
				combinations.push_front(lastCombination);
			}

			modelNode->setNumberOfStates(combinations.size());

			//for each combination set the state label: the labels of the
			//selected parent states joined by ","
			int s = 0;
			for(list<list<int> >::iterator itc = combinations.begin(); itc != combinations.end(); ++itc, ++s)
			{
				stringstream label;

				NodeList::const_iterator itp = parents.begin();
				for(list<int>::iterator iti = itc->begin(); iti != itc->end() && itp != parents.end(); ++iti, ++itp)
				{
					DiscreteNode * parent = (DiscreteNode *)*itp;
					label << parent->getStateLabel(*iti) << ",";
				}

				//drop the trailing comma
				string stateLabel = label.str();
				if(!stateLabel.empty())
				{
					stateLabel.erase(stateLabel.size() - 1);
				}
				modelNode->setStateLabel(s, stateLabel);
			}

			//set model table to all 0's, then 1's on the diagonal
			NumberList tableData = modelNode->getTable()->getData();
			for(size_t j = 0; j < tableData.size(); j++)
			{
				tableData[j] = 0;
			}

			const size_t diagonalStride = combinations.size() + 1;
			for(size_t j = 0; j < tableData.size(); j += diagonalStride)
			{
				tableData[j] = 1;
			}

			modelNode->getTable()->setData(tableData);
		}
	}

	return inputIDID;
}


//////////////////////////////////////////////////////////////////////////
//Copies the observation tables of the DID models into the observation
//nodes of the I-DID.
//
//For every time step t >= 2 the node <jObservationPrefix><t> is looked
//up in the I-DID and, under the same name, in every model. The model
//tables are written into the I-DID table back to back in list order;
//each model table is repeated
//	(I-DID table size / model table size) / numModels
//times (integer division).
//
//inputIDID          - I-DID whose observation tables are overwritten
//models             - DID models providing the tables
//jObservationPrefix - observation node name without time step suffix
//numModels          - number of models used to size each model's share
//
//Returns inputIDID.
//
//NOTE(review): the results of getNodeByName are used unchecked; the node
//is expected to exist in the I-DID and in every model.
//////////////////////////////////////////////////////////////////////////
Domain * IdidEnterModels::UpdateObservationNodesTables(Domain * inputIDID, list<Domain *> models, string jObservationPrefix, int numModels)
{
	ExpandTimeSteps expander;
	int numSteps =  expander.FindLastTimeStep(inputIDID);

	for(int t = 2; t <= numSteps; t++)
	{
		stringstream ss;
		ss << jObservationPrefix;
		ss << t;
		const string nodeName = ss.str();

		DiscreteNode * observationNodeIDID = (DiscreteNode *)inputIDID->getNodeByName(nodeName);

		NumberList iDIDTable = observationNodeIDID->getTable()->getData();
		const size_t iDIDTableSize = iDIDTable.size();
		size_t iDIDTableIndex = 0;

		for(list<Domain *>::iterator itd = models.begin(); itd != models.end(); ++itd)
		{
			Domain * currentModel = *itd;
			DiscreteNode * observationNodeDID = (DiscreteNode *)currentModel->getNodeByName(nodeName);
			NumberList didTable = observationNodeDID->getTable()->getData();
			const size_t didTableSize = didTable.size();

			//an empty model table or a non-positive model count has nothing
			//to contribute and would otherwise divide by zero below
			if(didTableSize == 0 || numModels <= 0)
			{
				continue;
			}

			const size_t numCopies = (iDIDTableSize / didTableSize) / static_cast<size_t>(numModels);

			for(size_t n = 0; n < numCopies; n++)
			{
				//never write past the end of the I-DID table, e.g. when
				//numModels is smaller than the number of models passed in
				if(iDIDTableSize - iDIDTableIndex < didTableSize)
				{
					break;
				}

				for(size_t k = 0; k < didTableSize; k++)
				{
					iDIDTable[iDIDTableIndex] = didTable[k];
					iDIDTableIndex++;
				}
			}
		}

		//iDIDTable is a local copy; write it back to the node
		observationNodeIDID->getTable()->setData(iDIDTable);
	}

	return inputIDID;
}

//////////////////////////////////////////////////////////////////////////
//Fills the table of every j action node <jActionPrefix>1 ..
//<jActionPrefix>N (N = last time step) from the policy trees.
//
//Only the first parent of the action node is consulted. Each of its
//state labels is split on "," : the first field is parsed with atoi and
//used as index into policyRoots, the remaining fields are passed as
//history to PolicyNode::GetActionForHistory.
//
//For each parent state one row with one entry per action state is
//written:
//	- 1 for the action returned by the policy tree, 0 elsewhere
//	- all 1's when no usable action is available (model index outside
//	  policyRoots, empty action name, or action not a state of the node)
//
//inputIDID     - I-DID whose action tables are overwritten
//policyRoots   - policy tree roots, indexed by model number
//jActionPrefix - action node name without the time step suffix
//
//Returns inputIDID.
//
//NOTE(review): atoi yields 0 for a non-numeric first field, so such a
//label is treated as model 0. The results of getNodeByName and
//getParents().front() are used unchecked.
//////////////////////////////////////////////////////////////////////////
Domain * IdidEnterModels::UpdateActionNodesTables(Domain * inputIDID, vector<PolicyNode*> policyRoots, string jActionPrefix)
{
	ExpandTimeSteps expander;
	int numSteps =  expander.FindLastTimeStep(inputIDID);

	for(int i = 1; i <= numSteps; i++)
	{
		stringstream ss;
		ss << jActionPrefix;
		ss << i;
		DiscreteNode * modelActionNode = (DiscreteNode *)inputIDID->getNodeByName(ss.str());
		DiscreteNode * parent = (DiscreteNode *)modelActionNode->getParents().front();

		NumberList nl = modelActionNode->getTable()->getData();

		size_t tableIndex = 0;

		for(int s = 0; s < parent->getNumberOfStates(); s++)
		{
			//one entry per action state, all starting at 0
			vector<int> ocurances (modelActionNode->getNumberOfStates());
			string tempParentState = parent->getStateLabel(s);
			list<string> history = CsvListToListOfStrings(tempParentState);

			//get model number and remove from list
			int modelNum = atoi(history.front().c_str());
			history.pop_front();

			//get action name by querying root node; every index is range
			//checked before it is used
			bool actionMarked = false;
			if(modelNum >= 0 && static_cast<size_t>(modelNum) < policyRoots.size() && policyRoots[modelNum] != 0)
			{
				string actionName = policyRoots[modelNum]->GetActionForHistory(history);
				if(actionName != "")
				{
					int stateIndex = modelActionNode->getStateIndex(actionName);
					if(stateIndex >= 0 && static_cast<size_t>(stateIndex) < ocurances.size())
					{
						ocurances[stateIndex] = 1;
						actionMarked = true;
					}
				}
			}

			//no action could be marked: fall back to a row of all 1's
			if(!actionMarked)
			{
				for(size_t v = 0; v < ocurances.size(); v++)
				{
					ocurances[v] = 1;
				}
			}

			//append the row, never writing past the end of the table
			for(size_t v = 0; v < ocurances.size() && tableIndex < nl.size(); v++)
			{
				nl[tableIndex] = ocurances[v];
				tableIndex ++;
			}
		}

		//nl is a local copy; write it back to the node
		modelActionNode->getTable()->setData(nl);
	}

	return inputIDID;
}


//////////////////////////////////////////////////////////////////////////
//Splits a comma separated string into its fields.
//
//Every "," ends a field; the text after the last comma (or the whole
//input when there is no comma) is always appended as the final field.
//The result therefore never is empty: "" gives one empty string and
//"a," gives "a" followed by an empty string. No trimming or quoting is
//performed.
//
//input - text to split
//
//Returns the fields in their original order.
//////////////////////////////////////////////////////////////////////////
list<string> IdidEnterModels::CsvListToListOfStrings(string input)
{
	list<string> out;

	//walk the string once instead of erasing the consumed prefix
	size_t fieldStart = 0;
	size_t commaPos = input.find(',', fieldStart);
	while (commaPos != string::npos)
	{
		out.push_back(input.substr(fieldStart, commaPos - fieldStart));
		fieldStart = commaPos + 1;
		commaPos = input.find(',', fieldStart);
	}

	//no trailing comma: whatever is left is the last field
	out.push_back(input.substr(fieldStart));

	return out;
}
